{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.5/dist-packages/sklearn/cross_validation.py:41: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.\n",
      "  \"This module will be removed in 0.20.\", DeprecationWarning)\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from utils import *\n",
    "import tensorflow as tf\n",
    "from sklearn.cross_validation import train_test_split\n",
    "import time\n",
    "from tqdm import tqdm\n",
    "import random\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from unidecode import unidecode\n",
    "import re\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style>\n",
       "    .dataframe thead tr:only-child th {\n",
       "        text-align: right;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: left;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "      <th>text</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Negative</td>\n",
       "      <td>Lebih-lebih lagi dengan  kemudahan internet da...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Positive</td>\n",
       "      <td>boleh memberi teguran kepada parti tetapi perl...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Negative</td>\n",
       "      <td>Adalah membingungkan mengapa masyarakat Cina b...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Positive</td>\n",
       "      <td>Kami menurunkan defisit daripada 6.7 peratus p...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Negative</td>\n",
       "      <td>Ini masalahnya. Bukan rakyat, tetapi sistem</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      label                                               text\n",
       "0  Negative  Lebih-lebih lagi dengan  kemudahan internet da...\n",
       "1  Positive  boleh memberi teguran kepada parti tetapi perl...\n",
       "2  Negative  Adalah membingungkan mengapa masyarakat Cina b...\n",
       "3  Positive  Kami menurunkan defisit daripada 6.7 peratus p...\n",
       "4  Negative        Ini masalahnya. Bukan rakyat, tetapi sistem"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv('sentiment-news-bahasa-v5.csv')\n",
    "Y = LabelEncoder().fit_transform(df.label)\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def textcleaning(string):\n",
    "    string = re.sub('http\\S+|www.\\S+', '',' '.join([i for i in string.split() if i.find('#')<0 and i.find('@')<0]))\n",
    "    string = unidecode(string).replace('.', '. ').replace(',', ', ')\n",
    "    string = re.sub('[^\\'\\\"A-Za-z\\- ]+', ' ', string)\n",
    "    return ' '.join([i for i in re.findall(\"[\\\\w']+|[;:\\-\\(\\)&.,!?\\\"]\", string) if len(i)>1]).lower()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(df.shape[0]):\n",
    "    df.iloc[i,1] = textcleaning(df.iloc[i,1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('polarity-negative-translated.txt','r') as fopen:\n",
    "    texts = fopen.read().split('\\n')\n",
    "labels = [0] * len(texts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('polarity-positive-translated.txt','r') as fopen:\n",
    "    positive_texts = fopen.read().split('\\n')\n",
    "labels += [1] * len(positive_texts)\n",
    "texts += positive_texts\n",
    "texts += df.iloc[:,1].tolist()\n",
    "labels += Y.tolist()\n",
    "\n",
    "assert len(labels) == len(texts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vocab from size: 18844\n",
      "Most common words [('yang', 14893), ('dan', 8177), ('tidak', 4579), ('untuk', 4023), ('dengan', 3349), ('filem', 3279)]\n",
      "Sample data [1614, 204, 5, 161, 218, 106, 301, 4, 78, 203] ['ringkas', 'bodoh', 'dan', 'membosankan', 'kanak-kanak', 'lelaki', 'remaja', 'yang', 'begitu', 'muda']\n"
     ]
    }
   ],
   "source": [
    "concat = ' '.join(texts).split()\n",
    "vocabulary_size = len(list(set(concat)))\n",
    "data, count, dictionary, rev_dictionary = build_dataset(concat, vocabulary_size)\n",
    "print('vocab from size: %d'%(vocabulary_size))\n",
    "print('Most common words', count[4:10])\n",
    "print('Sample data', data[:10], [rev_dictionary[i] for i in data[:10]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "idx_trainset = []\n",
    "for text in texts:\n",
    "    idx = []\n",
    "    for t in text.split():\n",
    "        try:\n",
    "            idx.append(dictionary[t])\n",
    "        except:\n",
    "            idx.append(3)\n",
    "    idx_trainset.append(idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_ngram_set(input_list, ngram_value):\n",
    "    return set(zip(*[input_list[i:] for i in range(ngram_value)]))\n",
    "\n",
    "def build_ngram(x_train):\n",
    "    global max_features\n",
    "    ngram_set = set()\n",
    "    for input_list in tqdm(x_train, total=len(x_train), ncols=70):\n",
    "        for i in range(2, ngram_range + 1):\n",
    "            set_of_ngram = create_ngram_set(input_list, ngram_value=i)\n",
    "            ngram_set.update(set_of_ngram)\n",
    "    start_index = max_features + 1\n",
    "    token_indice = {v: k + start_index for k, v in enumerate(ngram_set)}\n",
    "    indice_token = {token_indice[k]: k for k in token_indice}\n",
    "\n",
    "    max_features = np.max(list(indice_token.keys())) + 1\n",
    "    return token_indice\n",
    "\n",
    "def add_ngram(sequences, token_indice):\n",
    "    new_sequences = []\n",
    "    for input_list in sequences:\n",
    "        new_list = input_list[:]\n",
    "        for ngram_value in range(2, ngram_range + 1):\n",
    "            for i in range(len(new_list) - ngram_value + 1):\n",
    "                ngram = tuple(new_list[i:i + ngram_value])\n",
    "                if ngram in token_indice:\n",
    "                    new_list.append(token_indice[ngram])\n",
    "        new_sequences.append(new_list)\n",
    "    return new_sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "ngram_range = 2\n",
    "max_features = 20000\n",
    "maxlen = 80\n",
    "batch_size = 32\n",
    "embedded_size = 256"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████| 14279/14279 [00:00<00:00, 214848.30it/s]\n"
     ]
    }
   ],
   "source": [
    "token_indice = build_ngram(idx_trainset)\n",
    "X = add_ngram(idx_trainset, token_indice)\n",
    "X = tf.keras.preprocessing.sequence.pad_sequences(X, maxlen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_X, test_X, train_Y, test_Y = train_test_split(X, \n",
    "                                                    labels,\n",
    "                                                    test_size = 0.2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model:\n",
    "    def __init__(self, embedded_size, dict_size, dimension_output, learning_rate):\n",
    "        \n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.int32, [None])\n",
    "        encoder_embeddings = tf.Variable(tf.random_uniform([dict_size, embedded_size], -1, 1))\n",
    "        encoder_embedded = tf.nn.embedding_lookup(encoder_embeddings, self.X)\n",
    "        self.logits = tf.identity(tf.layers.dense(tf.reduce_mean(encoder_embedded, 1), dimension_output),\n",
    "                                  name=\"logits\")\n",
    "        self.cost = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
    "            logits=self.logits,\n",
    "            labels=self.Y))\n",
    "        self.optimizer = tf.train.AdamOptimizer(learning_rate).minimize(self.cost)\n",
    "        correct_pred = tf.equal(tf.argmax(self.logits, 1,output_type=tf.int32), self.Y)\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'fast-text/model.ckpt'"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Model(embedded_size,vocabulary_size+4,2,5e-4)\n",
    "sess.run(tf.global_variables_initializer())\n",
    "saver = tf.train.Saver(tf.global_variables())\n",
    "saver.save(sess, \"fast-text/model.ckpt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, pass acc: 0.000000, current acc: 0.582514\n",
      "time taken: 2.2845072746276855\n",
      "epoch: 0, training loss: 0.687419, training acc: 0.548718, valid loss: 0.680611, valid acc: 0.582514\n",
      "\n",
      "epoch: 1, pass acc: 0.582514, current acc: 0.630267\n",
      "time taken: 2.1589436531066895\n",
      "epoch: 1, training loss: 0.669211, training acc: 0.605513, valid loss: 0.666756, valid acc: 0.630267\n",
      "\n",
      "epoch: 2, pass acc: 0.630267, current acc: 0.661868\n",
      "time taken: 2.159257173538208\n",
      "epoch: 2, training loss: 0.647248, training acc: 0.656162, valid loss: 0.650173, valid acc: 0.661868\n",
      "\n",
      "epoch: 3, pass acc: 0.661868, current acc: 0.674508\n",
      "time taken: 2.153167963027954\n",
      "epoch: 3, training loss: 0.620276, training acc: 0.699263, valid loss: 0.631791, valid acc: 0.674508\n",
      "\n",
      "epoch: 4, pass acc: 0.674508, current acc: 0.682584\n",
      "time taken: 2.154153347015381\n",
      "epoch: 4, training loss: 0.590033, training acc: 0.731039, valid loss: 0.613892, valid acc: 0.682584\n",
      "\n",
      "epoch: 5, pass acc: 0.682584, current acc: 0.691011\n",
      "time taken: 2.1567935943603516\n",
      "epoch: 5, training loss: 0.559180, training acc: 0.757461, valid loss: 0.598171, valid acc: 0.691011\n",
      "\n",
      "epoch: 6, pass acc: 0.691011, current acc: 0.699438\n",
      "time taken: 2.1571340560913086\n",
      "epoch: 6, training loss: 0.529474, training acc: 0.776159, valid loss: 0.585134, valid acc: 0.699438\n",
      "\n",
      "epoch: 7, pass acc: 0.699438, current acc: 0.701545\n",
      "time taken: 2.140665054321289\n",
      "epoch: 7, training loss: 0.501551, training acc: 0.792310, valid loss: 0.574633, valid acc: 0.701545\n",
      "\n",
      "epoch: 8, pass acc: 0.701545, current acc: 0.704003\n",
      "time taken: 2.1551663875579834\n",
      "epoch: 8, training loss: 0.475495, training acc: 0.808287, valid loss: 0.566387, valid acc: 0.704003\n",
      "\n",
      "epoch: 9, pass acc: 0.704003, current acc: 0.712781\n",
      "time taken: 2.1510813236236572\n",
      "epoch: 9, training loss: 0.451247, training acc: 0.821278, valid loss: 0.560150, valid acc: 0.712781\n",
      "\n",
      "epoch: 10, pass acc: 0.712781, current acc: 0.714888\n",
      "time taken: 2.1647536754608154\n",
      "epoch: 10, training loss: 0.428711, training acc: 0.835586, valid loss: 0.555716, valid acc: 0.714888\n",
      "\n",
      "epoch: 11, pass acc: 0.714888, current acc: 0.716643\n",
      "time taken: 2.1564390659332275\n",
      "epoch: 11, training loss: 0.407780, training acc: 0.846998, valid loss: 0.552895, valid acc: 0.716643\n",
      "\n",
      "time taken: 2.1448323726654053\n",
      "epoch: 12, training loss: 0.388329, training acc: 0.856215, valid loss: 0.551509, valid acc: 0.715239\n",
      "\n",
      "time taken: 2.1546058654785156\n",
      "epoch: 13, training loss: 0.370225, training acc: 0.865081, valid loss: 0.551391, valid acc: 0.715941\n",
      "\n",
      "time taken: 2.1497790813446045\n",
      "epoch: 14, training loss: 0.353339, training acc: 0.871840, valid loss: 0.552389, valid acc: 0.714185\n",
      "\n",
      "time taken: 2.1799774169921875\n",
      "epoch: 15, training loss: 0.337551, training acc: 0.880091, valid loss: 0.554370, valid acc: 0.714185\n",
      "\n",
      "time taken: 2.1526472568511963\n",
      "epoch: 16, training loss: 0.322751, training acc: 0.888694, valid loss: 0.557216, valid acc: 0.713834\n",
      "\n",
      "time taken: 2.1555614471435547\n",
      "epoch: 17, training loss: 0.308845, training acc: 0.894575, valid loss: 0.560832, valid acc: 0.713483\n",
      "\n",
      "time taken: 2.155064105987549\n",
      "epoch: 18, training loss: 0.295752, training acc: 0.898964, valid loss: 0.565133, valid acc: 0.713834\n",
      "\n",
      "time taken: 2.1563830375671387\n",
      "epoch: 19, training loss: 0.283400, training acc: 0.903792, valid loss: 0.570052, valid acc: 0.712781\n",
      "\n",
      "time taken: 2.1473162174224854\n",
      "epoch: 20, training loss: 0.271728, training acc: 0.909849, valid loss: 0.575532, valid acc: 0.712079\n",
      "\n",
      "time taken: 2.1607401371002197\n",
      "epoch: 21, training loss: 0.260683, training acc: 0.913975, valid loss: 0.581524, valid acc: 0.710674\n",
      "\n",
      "break epoch:22\n",
      "\n"
     ]
    }
   ],
   "source": [
    "EARLY_STOPPING, CURRENT_CHECKPOINT, CURRENT_ACC, EPOCH = 10, 0, 0, 0\n",
    "while True:\n",
    "    lasttime = time.time()\n",
    "    if CURRENT_CHECKPOINT == EARLY_STOPPING:\n",
    "        print('break epoch:%d\\n'%(EPOCH))\n",
    "        break\n",
    "        \n",
    "    train_acc, train_loss, test_acc, test_loss = 0, 0, 0, 0\n",
    "    for i in range(0, (len(train_X) // batch_size) * batch_size, batch_size):\n",
    "        acc, loss, _ = sess.run([model.accuracy, model.cost, model.optimizer], \n",
    "                           feed_dict = {model.X : train_X[i:i+batch_size], model.Y : train_Y[i:i+batch_size]})\n",
    "        train_loss += loss\n",
    "        train_acc += acc\n",
    "    \n",
    "    for i in range(0, (len(test_X) // batch_size) * batch_size, batch_size):\n",
    "        acc, loss = sess.run([model.accuracy, model.cost], \n",
    "                           feed_dict = {model.X : test_X[i:i+batch_size], model.Y : test_Y[i:i+batch_size]})\n",
    "        test_loss += loss\n",
    "        test_acc += acc\n",
    "    \n",
    "    train_loss /= (len(train_X) // batch_size)\n",
    "    train_acc /= (len(train_X) // batch_size)\n",
    "    test_loss /= (len(test_X) // batch_size)\n",
    "    test_acc /= (len(test_X) // batch_size)\n",
    "    \n",
    "    if test_acc > CURRENT_ACC:\n",
    "        print('epoch: %d, pass acc: %f, current acc: %f'%(EPOCH,CURRENT_ACC, test_acc))\n",
    "        CURRENT_ACC = test_acc\n",
    "        CURRENT_CHECKPOINT = 0\n",
    "    else:\n",
    "        CURRENT_CHECKPOINT += 1\n",
    "        \n",
    "    print('time taken:', time.time()-lasttime)\n",
    "    print('epoch: %d, training loss: %f, training acc: %f, valid loss: %f, valid acc: %f\\n'%(EPOCH,train_loss,\n",
    "                                                                                          train_acc,test_loss,\n",
    "                                                                                          test_acc))\n",
    "    EPOCH += 1\n",
    "    saver.save(sess, \"fast-text/model.ckpt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from fast-text/model.ckpt\n",
      "             precision    recall  f1-score   support\n",
      "\n",
      "   negative       0.70      0.68      0.69      1342\n",
      "   positive       0.72      0.74      0.73      1514\n",
      "\n",
      "avg / total       0.71      0.71      0.71      2856\n",
      "\n"
     ]
    }
   ],
   "source": [
    "saver.restore(sess, \"fast-text/model.ckpt\")\n",
    "logits = sess.run(model.logits, feed_dict={model.X:test_X})\n",
    "print(metrics.classification_report(test_Y, np.argmax(logits,1), target_names = ['negative','positive']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_idx(texts):\n",
    "    idx_trainset = []\n",
    "    for text in texts:\n",
    "        idx = []\n",
    "        for t in text.split():\n",
    "            try:\n",
    "                idx.append(dictionary[t])\n",
    "            except:\n",
    "                pass\n",
    "        idx_trainset.append(idx)\n",
    "    return idx_trainset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.9678054 , 0.03219458]], dtype=float32)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "text = 'kerajaan sebenarnya sangat bencikan rakyatnya, minyak naik dan segalanya'\n",
    "new_vector = add_ngram(to_idx([text]), token_indice)\n",
    "sess.run(tf.nn.softmax(model.logits), feed_dict={model.X:new_vector})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[1.7911309e-05, 9.9998212e-01]], dtype=float32)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "text = 'kerajaan sebenarnya sangat sayangkan rakyatnya'\n",
    "new_vector = add_ngram(to_idx([text]), token_indice)\n",
    "sess.run(tf.nn.softmax(model.logits), feed_dict={model.X:new_vector})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.84641653, 0.15358353]], dtype=float32)"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "text = 'kerajaan sebenarnya sangat sayangkan rakyatnya, tetapi sebenarnya benci'\n",
    "new_vector = add_ngram(to_idx([text]), token_indice)\n",
    "sess.run(tf.nn.softmax(model.logits), feed_dict={model.X:new_vector})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "with open('fast-text-sentiment.json','w') as fopen:\n",
    "    fopen.write(json.dumps({'dictionary':dictionary,'reverse_dictionary':rev_dictionary}))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "strings=','.join([n.name for n in tf.get_default_graph().as_graph_def().node if \"Variable\" in n.op or n.name.find('Placeholder') >= 0 or n.name.find('logits') == 0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "def freeze_graph(model_dir, output_node_names):\n",
    "\n",
    "    if not tf.gfile.Exists(model_dir):\n",
    "        raise AssertionError(\n",
    "            \"Export directory doesn't exists. Please specify an export \"\n",
    "            \"directory: %s\" % model_dir)\n",
    "\n",
    "    checkpoint = tf.train.get_checkpoint_state(model_dir)\n",
    "    input_checkpoint = checkpoint.model_checkpoint_path\n",
    "    \n",
    "    absolute_model_dir = \"/\".join(input_checkpoint.split('/')[:-1])\n",
    "    output_graph = absolute_model_dir + \"/frozen_model.pb\"\n",
    "    clear_devices = True\n",
    "    with tf.Session(graph=tf.Graph()) as sess:\n",
    "        saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)\n",
    "        saver.restore(sess, input_checkpoint)\n",
    "        output_graph_def = tf.graph_util.convert_variables_to_constants(\n",
    "            sess,\n",
    "            tf.get_default_graph().as_graph_def(),\n",
    "            output_node_names.split(\",\")\n",
    "        ) \n",
    "        with tf.gfile.GFile(output_graph, \"wb\") as f:\n",
    "            f.write(output_graph_def.SerializeToString())\n",
    "        print(\"%d ops in the final graph.\" % len(output_graph_def.node))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from fast-text/model.ckpt\n",
      "INFO:tensorflow:Froze 11 variables.\n",
      "Converted 11 variables to const ops.\n",
      "22 ops in the final graph.\n"
     ]
    }
   ],
   "source": [
    "freeze_graph(\"fast-text\", strings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_graph(frozen_graph_filename):\n",
    "    with tf.gfile.GFile(frozen_graph_filename, \"rb\") as f:\n",
    "        graph_def = tf.GraphDef()\n",
    "        graph_def.ParseFromString(f.read())\n",
    "    with tf.Graph().as_default() as graph:\n",
    "        tf.import_graph_def(graph_def)\n",
    "    return graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2856, 2)"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "g=load_graph('fast-text/frozen_model.pb')\n",
    "x = g.get_tensor_by_name('import/Placeholder:0')\n",
    "logits = g.get_tensor_by_name('import/logits:0')\n",
    "test_sess = tf.InteractiveSession(graph=g)\n",
    "predicted = test_sess.run(logits,feed_dict={x:test_X})\n",
    "predicted.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "with open('token-indice.pkl','wb') as fopen:\n",
    "    pickle.dump(token_indice, fopen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
